#------------------------------------------------------------------------------
# Copyright (c) 2013-2024, Nucleic Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
#------------------------------------------------------------------------------
from types import FunctionType

from bytecode import CompilerFlags
from atom.api import Atom, List, Str, Tuple, Typed

from .compiler_nodes import TemplateNode


class TemplateInstance(Atom):
    """ A class representing a template instantiation.

    Instances of this class are created by instances of Template. They
    should not be created directly by user code.

    """
    #: The template node generated by the specialization function.
    node = Typed(TemplateNode)

    def __call__(self, parent=None, **kwargs):
        """ Instantiate the list of items for the template.

        Parameters
        ----------
        parent : Object, optional
            The parent object for the generated objects.

        **kwargs
            Additional keyword arguments to apply to the returned
            items. Each key/value pair is assigned with ``setattr`` on
            every generated item.

        Returns
        -------
        result : list
            The list of objects generated by the template.

        """
        items = self.node(parent)
        if items and kwargs:
            for item in items:
                for key, value in kwargs.items():
                    setattr(item, key, value)
        return items

    def __getattr__(self, name):
        """ Get the named attribute for the template instance.

        This method is only consulted when normal attribute lookup
        fails. It will retrieve the value from the template scope, if
        present. Otherwise, it will raise an AttributeError.

        """
        node = self.node
        # If no node has been assigned (self.node is None) there is no
        # scope to search; fall through to the standard AttributeError
        # rather than leaking an error about 'NoneType'.
        if node is not None:
            try:
                return node.scope[name]
            except KeyError:
                pass
        # Raised outside the except clause so the KeyError is not
        # attached as the context of the AttributeError.
        msg = "'%s' object has no attribute '%s'"
        raise AttributeError(msg % (type(self).__name__, name))


class Specialization(Atom):
    """ One specialization of a template.

    A Template creates these to pair a builder function with the
    parameter specification it was registered for.

    """
    #: Builder invoked with the template arguments; it returns the
    #: TemplateNode for this specialization.
    func = Typed(FunctionType)

    #: Parameter spec: a tuple of (is_type, value) pairs produced by
    #: Template.make_paramspec.
    paramspec = Tuple()


class Template(Atom):
    """ A class representing a 'template' definition.

    A template holds a list of specializations. Calling the template
    selects the best matching specialization for the given positional
    arguments, invokes its function to build a TemplateNode, and caches
    the resulting TemplateInstance keyed on the argument tuple.

    """
    #: The name associated with the template.
    name = Str()

    #: The module name in which the template lives.
    module = Str()

    #: The list of specializations associated with the template. This
    #: list is populated by the compiler.
    specializations = List(Specialization)

    #: The cache of template instantiations, keyed on the tuple of
    #: arguments passed to ``__call__``.
    cache = Typed(dict, ())

    def __repr__(self):
        """ A nice repr for objects created by the `template` keyword.

        """
        return "<template '%s.%s'>" % (self.module, self.name)

    def make_paramspec(self, items):
        """ Convert the given items into a parameter specification.

        Parameters
        ----------
        items : tuple
            A tuple of parameter objects.

        Returns
        -------
        result : tuple
            A tuple of 2-tuples representing the parameter spec. Each
            2-tuple is of the form (bool, value) where the boolean
            indicates whether the value is a type.

        """
        return tuple((isinstance(item, type), item) for item in items)

    def add_specialization(self, params, func):
        """ Add a specialization to the template.

        Parameters
        ----------
        params : tuple
            A tuple specifying the parameter specializations for the
            positional arguments of the template function. A value of
            None indicates that the parameter can be of any type.

        func : FunctionType
            A function which will return a TemplateNode when invoked
            with user arguments.

        Raises
        ------
        TypeError
            If a specialization with an identical parameter spec has
            already been registered.

        """
        paramspec = self.make_paramspec(params)
        for spec in self.specializations:
            if spec.paramspec == paramspec:
                msg = 'ambiguous template specialization for parameters: %s'
                raise TypeError(msg % (params,))
        spec = Specialization()
        spec.func = func
        spec.paramspec = paramspec
        self.specializations.append(spec)

    def get_specialization(self, args):
        """ Get the specialization for the given arguments.

        Parameters
        ----------
        args : tuple
            A tuple of arguments to match against the current template
            specializations.

        Returns
        -------
        result : Specialization or None
            The best matching specialization for the arguments, or None
            if no match could be found.

        Raises
        ------
        TypeError
            If two or more specializations match with the same best
            score.

        """
        matches = []
        n_args = len(args)
        argspec = None

        for spec in self.specializations:
            # Before scoring for a match, rule out incompatible specs
            # based on the number of arguments. Too few arguments is no
            # match, and too many is no match unless the specialization
            # accepts variadic arguments.
            n_params = len(spec.paramspec)
            if n_args < n_params:
                continue
            n_total = n_params + len(spec.func.__defaults__ or ())
            variadic = spec.func.__code__.co_flags & CompilerFlags.VARARGS
            if n_args > n_total and not variadic:
                continue

            # Defer creating the argspec until needed.
            if argspec is None:
                argspec = self.make_paramspec(args)

            # Scoring a match is done by ranking the arguments using a
            # closeness measure. If an argument is an exact match to
            # the parameter, it gets a score of 0. If an argument is a
            # subtype of a type parameter, it gets a score equal to the
            # index of the type in the mro of the subtype. If the arg
            # is not an exact match or a subtype, the specialization is
            # not a match. If the parameter has no specialization, the
            # argument gets a score of 1 << 16, which is arbitrary but
            # large enough that it's highly unlikely to be outweighed
            # by any mro type match (that would take an mro depth of
            # 1 << 16). Python ints are unbounded, so the sum cannot
            # overflow. The default and variadic parameters do not
            # enter into the scoring since 'add_specialization' rejects
            # any specialization which is ambiguous. The lowest score
            # wins and a tie raises an exception.
            score = 0
            for (a_type, arg), (p_type, param) in zip(argspec, spec.paramspec):
                if arg == param:
                    continue
                if param is None:
                    score += 1 << 16
                    continue
                if p_type and a_type and param in arg.__mro__:
                    score += arg.__mro__.index(param)
                    continue
                score = -1
                break
            if score >= 0:
                matches.append((score, spec))

        if not matches:
            return None
        if len(matches) == 1:
            return matches[0][1]

        # Sort on the score alone. Specialization defines no ordering,
        # so a plain tuple sort would try to compare two Specialization
        # objects on a tie and raise an unrelated TypeError before the
        # ambiguity check below could run.
        matches.sort(key=lambda match: match[0])
        (best_score, best_spec), (next_score, _) = matches[0], matches[1]
        if best_score == next_score:
            msg = "ambiguous template instantiation for arguments: %s"
            raise TypeError(msg % (args,))
        return best_spec

    def __call__(self, *args):
        """ Instantiate the template for the given arguments.

        Parameters
        ----------
        *args
            The arguments to use to instantiate the template. The tuple
            of arguments is used as a key into ``cache`` and must
            therefore be hashable.

        Returns
        -------
        result : TemplateInstance
            The instantiated template. Repeated calls with equal
            arguments return the same cached instance.

        Raises
        ------
        TypeError
            If no specialization matches the arguments, or if the match
            is ambiguous.

        """
        inst = self.cache.get(args)
        if inst is not None:
            return inst
        spec = self.get_specialization(args)
        if spec is not None:
            inst = TemplateInstance()
            inst.node = spec.func(*args)
            self.cache[args] = inst
            return inst
        msg = 'no matching template specialization for arguments: %s'
        raise TypeError(msg % (args,))
